import numpy as np
import torch
import torch.distributions as td
from functools import partial


class ToySampler:
    """Minimal adapter that exposes a toy dataset through ``sample(batch_size)``.

    Presumably mirrors the interface of a ``DataSampler`` defined elsewhere in
    the project (TODO confirm); it does no work of its own.
    """

    def __init__(self, dataset):
        # Any object providing a ``sample`` method.
        self.dataset = dataset

    def sample(self, batch_size):
        """Forward ``batch_size`` to ``self.dataset.sample`` and return its result as-is."""
        draw = self.dataset.sample
        return draw(batch_size)

# ------------------------
# Toy Datasets
# ------------------------
def get_toydataset(data_name, datasize):
    """Instantiate the toy dataset registered under ``data_name``.

    Args:
        data_name: registry key, e.g. ``'8gaussian'`` or ``'moon'``.
        datasize: forwarded unchanged to the dataset constructor.

    Returns:
        A new dataset instance.

    Raises:
        ValueError: if ``data_name`` is not a registered dataset name.
    """
    registry = {
        '1d_2gaussian': MixMultiVariateNormal1D,
        '8gaussian': MixMultiVariateNormal,
        'checkerboard': CheckerBoard,
        'spiral': Spiral,
        'moon': Moon,
        '25gaussian': SquareGaussian,
        'twocircles': partial(Circles, centers=[[0,0], [0,0]], radius=[8,16], sigmas=[0.2, 0.2]),
    }
    try:
        factory = registry[data_name]
    except KeyError:
        # Previously an unknown name surfaced as "'NoneType' object is not callable".
        raise ValueError(
            f"Unknown toy dataset {data_name!r}; expected one of {sorted(registry)}"
        ) from None
    return factory(datasize)


class CheckerBoard:
    """2-D checkerboard: uniform points on a square, kept only on alternating cells.

    ``datasize`` is accepted so all toy datasets share one constructor
    signature; it is not used.
    """

    def __init__(self, datasize):
        pass

    def sample(self, n):
        """Draw exactly ``n[0]`` checkerboard points.

        Args:
            n: indexable whose first element is the number of points; only
                ``n[0]`` is read.

        Returns:
            float32 tensor of shape ``(n[0], 2)`` with coordinates in
            ``[-(freq // 2) * pi / 3, (freq // 2) * pi / 3]``.
        """
        n = n[0]
        freq = 5
        half_width = (freq // 2) * np.pi
        # Rejection sampling: each round draws 3n uniform candidates and keeps
        # about half of them. Loop until at least n survive; a single round can
        # come up short for small n. The seed entry keeps n == 0 well-formed.
        accepted = [np.empty((0, 2))]
        count = 0
        while count < n:
            x = np.random.uniform(-half_width, half_width, size=(3 * n, 2))
            s0 = np.sin(x[:, 0])
            s1 = np.sin(x[:, 1])
            # Reject candidates whose two sines share a strict sign; keep the rest.
            same_sign = ((s0 > 0.0) & (s1 > 0.0)) | ((s0 < 0.0) & (s1 < 0.0))
            kept = x[~same_sign]
            accepted.append(kept)
            count += len(kept)
        sample = torch.Tensor(np.concatenate(accepted, axis=0)[:n])
        return sample / 3.


class Spiral:
    """Single-arm 2-D spiral with Gaussian jitter.

    ``datasize`` is accepted so all toy datasets share one constructor
    signature; it is not used.
    """

    def __init__(self, datasize):
        pass

    def sample(self, n):
        """Draw ``n[0]`` jittered points along the spiral arm.

        Args:
            n: indexable whose first element is the number of points; only
                ``n[0]`` is read.

        Returns:
            float32 tensor of shape ``(n[0], 2)``.
        """
        n = n[0]
        # Angle in [-pi/2, 5pi/2]; sqrt of a uniform biases draws toward larger
        # angles, i.e. toward the outer (longer) end of the arm.
        theta = np.sqrt(np.random.rand(n)) * 3 * np.pi - 0.5 * np.pi
        r = theta + np.pi
        arm = np.stack([np.cos(theta) * r, np.sin(theta) * r], axis=1)
        return torch.Tensor(arm + 0.25 * np.random.randn(n, 2))


class Moon:
    """Two interleaved noisy half-circle arcs ("two moons"), scaled by 12.

    ``datasize`` is accepted so all toy datasets share one constructor
    signature; it is not used.
    """

    def __init__(self, datasize):
        pass

    def sample(self, n):
        """Draw exactly ``n[0]`` points split across the two arcs.

        Points on each arc sit at evenly spaced angles (``np.linspace``) plus
        Gaussian noise with std 0.5.

        Args:
            n: indexable whose first element is the number of points; only
                ``n[0]`` is read.

        Returns:
            float32 tensor of shape ``(n[0], 2)``. The first ``ceil(n/2)`` rows
            belong to the first arc and the rest to the second.
        """
        n = n[0]
        # Give the odd point (if any) to the first arc so the output always has
        # n rows. For even n this is the same n // 2 + n // 2 split as before.
        n_first = n - n // 2
        n_second = n // 2
        t_first = np.linspace(0, np.pi, n_first)
        u = np.stack([np.cos(t_first) + .5, -np.sin(t_first) + .2], axis=1) * 12.
        u += 0.5 * np.random.normal(size=u.shape)
        t_second = np.linspace(0, np.pi, n_second)
        v = np.stack([np.cos(t_second) - .5, np.sin(t_second) - .2], axis=1) * 12.
        v += 0.5 * np.random.normal(size=v.shape)
        return torch.Tensor(np.concatenate([u, v], axis=0))


class MixMultiVariateNormal1D:
    """Equal-weight 1-D Gaussian mixture with components at -2 and 2.

    Args:
        datasize: accepted so all toy datasets share one constructor
            signature; not used.
        sigma: standard deviation of every component (it scales ``randn``).
    """

    def __init__(self, datasize, sigma=0.1):
        self.mus = [-2, 2]
        self.sigma = sigma

    def sample(self, n):
        """Draw exactly ``n[0]`` samples, grouped by component in ``self.mus`` order.

        Args:
            n: indexable whose first element is the number of samples; only
                ``n[0]`` is read.

        Returns:
            float32 tensor of shape ``(n[0], 1)``.
        """
        n = n[0]
        # Split n as evenly as possible, with the remainder going to the first
        # components. Previously int(n / 2) dropped a sample for odd n.
        base, extra = divmod(n, len(self.mus))
        counts = [base + (1 if i < extra else 0) for i in range(len(self.mus))]
        samples = [torch.randn(c, 1) * self.sigma + mu for c, mu in zip(counts, self.mus)]
        return torch.cat(samples, dim=0)


class MixMultiVariateNormal:
    """Equal-weight mixture of ``num`` 2-D Gaussians placed evenly on a circle.

    Args:
        datasize: accepted so all toy datasets share one constructor
            signature; not used.
        radius: distance of every component mean from the origin.
        num: number of mixture components.
        sigma: scale of the isotropic covariance ``sigma * I``. It is passed as
            the covariance matrix, so it acts as a variance: each coordinate
            has standard deviation ``sqrt(sigma)``.
    """

    def __init__(self, datasize, radius=12, num=8, sigma=0.4):
        arc = 2 * np.pi / num
        mus = [
            torch.Tensor([np.cos(arc * idx) * radius, np.sin(arc * idx) * radius])
            for idx in range(num)
        ]
        dim = len(mus[0])

        self.num = num
        self.dists = [
            td.multivariate_normal.MultivariateNormal(mu, sigma * torch.eye(dim))
            for mu in mus
        ]

    def sample(self, n):
        """Draw ``n[0]`` samples, ``n[0] // num`` from each component in order.

        Args:
            n: indexable whose first element is the number of samples; only
                ``n[0]`` is read.

        Returns:
            tensor of shape ``(n[0], 2)``.

        Raises:
            ValueError: if ``n[0]`` is not a multiple of ``num``. This uses an
                explicit check because an ``assert`` is stripped under ``-O``.
        """
        n = n[0]
        if n % self.num != 0:
            raise ValueError(f"n ({n}) must be a multiple of num ({self.num})")
        per_component = n // self.num
        return torch.cat([dist.sample([per_component]) for dist in self.dists], dim=0)


class SquareGaussian:
    """Equal-weight mixture of Gaussians on a square grid with spacing 8, centered at 0.

    With the default ``num=25`` the means form the 5x5 grid
    ``{-16, -8, 0, 8, 16}^2``, ordered x-major (x outer, y inner).

    Args:
        datasize: accepted so all toy datasets share one constructor
            signature; not used.
        num: number of components; must be a perfect square.
        sigma: scale of the isotropic covariance ``sigma * I``. It is passed as
            the covariance matrix, so it acts as a variance.

    Raises:
        ValueError: if ``num`` is not a perfect square. Previously a
            non-default ``num`` was silently truncated against the hard-coded
            25-point grid.
    """

    def __init__(self, datasize, num=25, sigma=0.01):
        side = int(round(np.sqrt(num)))
        if side * side != num:
            raise ValueError(f"num ({num}) must be a perfect square")
        spacing = 8
        # Symmetric coordinates; for side=5 this is -16, -8, 0, 8, 16.
        coords = [spacing * (i - (side - 1) / 2) for i in range(side)]
        mus = [torch.Tensor([x, y]) for x in coords for y in coords]
        dim = len(mus[0])

        self.num = num
        self.dists = [
            td.multivariate_normal.MultivariateNormal(mu, sigma * torch.eye(dim))
            for mu in mus
        ]

    def sample(self, n):
        """Draw ``n[0]`` samples, ``n[0] // num`` from each component in order.

        Args:
            n: indexable whose first element is the number of samples; only
                ``n[0]`` is read.

        Returns:
            tensor of shape ``(n[0], 2)``.

        Raises:
            ValueError: if ``n[0]`` is not a multiple of ``num``.
        """
        n = n[0]
        if n % self.num != 0:
            raise ValueError(f"n ({n}) must be a multiple of num ({self.num})")
        per_component = n // self.num
        return torch.cat([dist.sample([per_component]) for dist in self.dists], dim=0)
    

class Circles:
    """Noisy 2-D circles with per-circle center, radius and noise scale.

    Args:
        datasize: accepted so all toy datasets share one constructor
            signature; not used.
        centers: list of ``[x, y]`` centers, one per circle.
        radius: list of radii, one per circle.
        sigmas: list of Gaussian noise stds, one per circle.

    These are repeated with list ``*`` in ``sample``, so they are expected to
    be Python lists or tuples, not arrays.

    Raises:
        ValueError: if the three lists differ in length.
    """

    def __init__(self, datasize, centers, radius, sigmas):
        if not (len(centers) == len(radius) == len(sigmas)):
            raise ValueError(
                "centers, radius and sigmas must have equal lengths, got "
                f"{len(centers)}, {len(radius)}, {len(sigmas)}"
            )
        self.num_circles = len(centers)
        self.centers = centers
        self.radius = radius
        self.sigmas = sigmas

    def sample(self, n):
        """Draw ``n[0]`` points, with row ``i`` on circle ``i % num_circles``.

        Args:
            n: indexable whose first element is the number of points; only
                ``n[0]`` is read.

        Returns:
            float32 tensor of shape ``(n[0], 2)``.

        Raises:
            ValueError: if ``n[0]`` is not a multiple of the number of circles.
        """
        n = n[0]
        if n % self.num_circles != 0:
            raise ValueError(
                f"n ({n}) must be a multiple of the number of circles ({self.num_circles})"
            )
        reps = n // self.num_circles
        # List repetition interleaves circles: [c0, c1, ..., c0, c1, ...].
        # The reshape keeps shape (n, 2) even when n == 0.
        centers = torch.tensor(self.centers * reps, dtype=torch.float32).reshape(n, 2)
        radius = torch.tensor(self.radius * reps, dtype=torch.float32)[:, None]
        sigmas = torch.tensor(self.sigmas * reps, dtype=torch.float32)[:, None]
        noise = torch.randn(size=(n, 2))
        # Normalized Gaussian vectors give directions uniform on the unit circle.
        z = torch.randn(size=(n, 2))
        z = z / torch.norm(z, dim=1, keepdim=True)
        return centers + radius * z + sigmas * noise
